{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/home/husein/t5/prepare/mesolitica-tpu.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/seq2seq-base-500k-09-09-2020.tar.gz\n",
    "# !tar -xf seq2seq-base-500k-09-09-2020.tar.gz\n",
    "# !rm seq2seq-base-500k-09-09-2020.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor import problems\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import logging\n",
    "\n",
    "logger = logging.getLogger()\n",
    "tf.logging.set_verbosity(tf.logging.DEBUG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "\n",
    "vocab = 'sp10m.cased.t5.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.Load(vocab)\n",
    "\n",
    "class Encoder:\n",
    "    def __init__(self, sp):\n",
    "        self.sp = sp\n",
    "        self.vocab_size = sp.GetPieceSize() + 100\n",
    "    \n",
    "    def encode(self, s):\n",
    "        return self.sp.EncodeAsIds(s)\n",
    "    \n",
    "    def decode(self, ids, strip_extraneous=False):\n",
    "        return self.sp.DecodeIds(list(ids))\n",
    "    \n",
    "encoder = Encoder(sp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from glob import glob\n",
    "\n",
    "@registry.register_problem\n",
    "class Seq2Seq(text_problems.Text2TextProblem):\n",
    "\n",
    "    @property\n",
    "    def approx_vocab_size(self):\n",
    "        return 32100\n",
    "    \n",
    "    @property\n",
    "    def is_generate_per_split(self):\n",
    "        return False\n",
    "            \n",
    "    def feature_encoders(self, data_dir):\n",
    "        encoder = Encoder(sp)\n",
    "        return {\n",
    "            \"inputs\": encoder,\n",
    "            \"targets\": encoder\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROBLEM = 'seq2_seq'\n",
    "t2t_problem = problems.problem(PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "ckpt_path = 'base/model.ckpt-500000'\n",
    "# ckpt_path = tf.train.latest_checkpoint('gs://mesolitica-tpu-general/t2t-base/')\n",
    "# ckpt_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.layers import common_layers\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import metrics\n",
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_fn_builder(\n",
    "    input_files,\n",
    "    max_seq_length_encoder,\n",
    "    max_seq_length_decoder,\n",
    "    is_training,\n",
    "    num_cpu_threads = 4,\n",
    "):\n",
    "\n",
    "    data_fields = {\n",
    "        'inputs': tf.VarLenFeature(tf.int64),\n",
    "        'targets': tf.VarLenFeature(tf.int64),\n",
    "    }\n",
    "    data_len = {\n",
    "        'inputs': max_seq_length_encoder,\n",
    "        'targets': max_seq_length_decoder,\n",
    "    }\n",
    "\n",
    "    def parse(serialized_example):\n",
    "\n",
    "        features = tf.parse_single_example(\n",
    "            serialized_example, features = data_fields\n",
    "        )\n",
    "        for k in features.keys():\n",
    "            features[k] = features[k].values\n",
    "            features[k] = tf.pad(\n",
    "                features[k], [[0, data_len[k] - tf.shape(features[k])[0]]]\n",
    "            )\n",
    "            features[k].set_shape((data_len[k]))\n",
    "\n",
    "        return features\n",
    "\n",
    "    def input_fn(params):\n",
    "        batch_size = params['batch_size']\n",
    "\n",
    "        if is_training:\n",
    "            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))\n",
    "            d = d.repeat()\n",
    "            d = d.shuffle(buffer_size = len(input_files))\n",
    "            cycle_length = min(num_cpu_threads, len(input_files))\n",
    "            d = d.apply(\n",
    "                tf.contrib.data.parallel_interleave(\n",
    "                    tf.data.TFRecordDataset,\n",
    "                    sloppy = is_training,\n",
    "                    cycle_length = cycle_length,\n",
    "                )\n",
    "            )\n",
    "            d = d.shuffle(buffer_size = 100)\n",
    "        else:\n",
    "            d = tf.data.TFRecordDataset(input_files)\n",
    "            d = d.repeat()\n",
    "        d = d.map(parse, num_parallel_calls = 32)\n",
    "        d = d.apply(\n",
    "            tf.contrib.data.map_and_batch(\n",
    "                lambda record: _decode_record(record, data_fields),\n",
    "                batch_size = batch_size,\n",
    "                num_parallel_batches = num_cpu_threads,\n",
    "                drop_remainder = True,\n",
    "            )\n",
    "        )\n",
    "        return d\n",
    "\n",
    "    return input_fn\n",
    "\n",
    "def _decode_record(example, name_to_features):\n",
    "    for name in list(example.keys()):\n",
    "        t = example[name]\n",
    "        if t.dtype == tf.int64:\n",
    "            t = tf.to_int32(t)\n",
    "        example[name] = t\n",
    "\n",
    "    return example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "251"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_files = tf.gfile.Glob('gs://mesolitica-tpu-general/t2t/data/seq2*')\n",
    "len(input_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor 'IteratorGetNext_1:0' shape=(2, 1024) dtype=int32>,\n",
       " <tf.Tensor 'IteratorGetNext_1:1' shape=(2, 1024) dtype=int32>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_input_fn = input_fn_builder(\n",
    "    input_files = input_files,\n",
    "    max_seq_length_encoder = 1024,\n",
    "    max_seq_length_decoder = 1024,\n",
    "    is_training = True,\n",
    ")\n",
    "dataset = train_input_fn({'batch_size': 2})\n",
    "dataset = dataset._make_one_shot_iterator().get_next()\n",
    "X = dataset['inputs']\n",
    "Y = dataset['targets']\n",
    "X, Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(self, X, Y, HPARAMS = \"transformer_base\", DATA_DIR = 't2t/data'):\n",
    "        \n",
    "        self.X = X\n",
    "        self.Y = Y\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        maxlen_decode = tf.reduce_max(self.X_seq_len)\n",
    "        \n",
    "        x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "        y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "        \n",
    "        features = {\n",
    "            \"inputs\": x,\n",
    "            \"targets\": y,\n",
    "            \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "        }\n",
    "        self.features = features\n",
    "        \n",
    "        Modes = tf.estimator.ModeKeys\n",
    "        hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
    "        hparams.filter_size = 3072\n",
    "        hparams.hidden_size = 768\n",
    "        hparams.num_heads = 12\n",
    "        hparams.num_hidden_layers = 8\n",
    "        hparams.vocab_divisor = 128\n",
    "        hparams.label_smoothing = 0.0\n",
    "        hparams.shared_embedding_and_softmax_weights = False\n",
    "        hparams.dropout = 0.1\n",
    "        hparams.max_length = 1024\n",
    "        hparams.multiproblem_mixing_schedule = \"pretrain\"\n",
    "\n",
    "        hparams.optimizer = \"Adafactor\"\n",
    "        hparams.learning_rate_warmup_steps = 10000\n",
    "        hparams.learning_rate_schedule = \"rsqrt_decay\"\n",
    "        \n",
    "        translate_model = registry.model('transformer')(hparams, Modes.TRAIN)\n",
    "        self.translate_model = translate_model\n",
    "        logits, _ = translate_model(features)\n",
    "        self.logits = logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StudentModel:\n",
    "    def __init__(self, X, Y, HPARAMS = \"transformer_base\", DATA_DIR = 't2t/data'):\n",
    "        \n",
    "        with tf.compat.v1.variable_scope('student') as vs:\n",
    "        \n",
    "            self.X = X\n",
    "            self.Y = Y\n",
    "\n",
    "            self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "            maxlen_decode = tf.reduce_max(self.X_seq_len)\n",
    "\n",
    "            x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "            y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "\n",
    "            features = {\n",
    "                \"inputs\": x,\n",
    "                \"targets\": y,\n",
    "                \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "            }\n",
    "            self.features = features\n",
    "\n",
    "            Modes = tf.estimator.ModeKeys\n",
    "            hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
    "            hparams.filter_size = 1200\n",
    "            hparams.hidden_size = 312\n",
    "            hparams.num_heads = 12\n",
    "            hparams.num_hidden_layers = 4\n",
    "            hparams.vocab_divisor = 128\n",
    "            hparams.label_smoothing = 0.0\n",
    "            hparams.shared_embedding_and_softmax_weights = False\n",
    "            hparams.dropout = 0.1\n",
    "            hparams.max_length = 1024\n",
    "            hparams.multiproblem_mixing_schedule = \"pretrain\"\n",
    "\n",
    "            hparams.optimizer = \"Adafactor\"\n",
    "            hparams.learning_rate_warmup_steps = 10000\n",
    "            hparams.learning_rate_schedule = \"rsqrt_decay\"\n",
    "\n",
    "            translate_model = registry.model('transformer')(hparams, Modes.TRAIN)\n",
    "            self.translate_model = translate_model\n",
    "            logits, _ = translate_model(features)\n",
    "            self.logits = logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7f45d8c712f0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7f45d8c712f0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function framework at 0x7f45d8c712f0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:2253: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:2253: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32128_768.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32128_768.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32128_768.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32128_768.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:95: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:95: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7f45d8258510> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7f45d8258510> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function attention_bias_to_padding at 0x7f45d8258510> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function layers at 0x7f45d8794048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f45cce74dd8>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function layers at 0x7f45d8794048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f45cce74dd8>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function layers at 0x7f45d8794048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f45cce74dd8>\n",
      "INFO:tensorflow:Transforming body output with symbol_modality_32128_768.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32128_768.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32128_312.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32128_312.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32128_312.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32128_312.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32128_312.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32128_312.top\n"
     ]
    }
   ],
   "source": [
    "model = Model(X, Y)\n",
    "student = StudentModel(X, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def padded_cross_entropy_loss(logits, labels, smoothing = 0.0, vocab_size = 32128):\n",
    "    with tf.name_scope('loss'):\n",
    "\n",
    "        if labels is not None:\n",
    "            with tf.name_scope('smoothing_cross_entropy'):\n",
    "                confidence = 1.0 - smoothing\n",
    "                vocab_float = tf.cast(vocab_size - 1, tf.float32)\n",
    "                low_confidence = (1.0 - confidence) / vocab_float\n",
    "                soft_targets = tf.one_hot(\n",
    "                    labels,\n",
    "                    depth = vocab_size,\n",
    "                    on_value = confidence,\n",
    "                    off_value = low_confidence,\n",
    "                )\n",
    "                xentropy = tf.nn.softmax_cross_entropy_with_logits(\n",
    "                    logits = logits, labels = soft_targets\n",
    "                )\n",
    "\n",
    "                normalizing_constant = -(\n",
    "                    confidence * tf.math.log(confidence)\n",
    "                    + vocab_float\n",
    "                    * low_confidence\n",
    "                    * tf.math.log(low_confidence + 1e-20)\n",
    "                )\n",
    "                xentropy -= normalizing_constant\n",
    "\n",
    "            weights = tf.cast(tf.not_equal(labels, 0), tf.float32)\n",
    "            return tf.reduce_sum(xentropy * weights), weights\n",
    "\n",
    "        else:\n",
    "            loss = tf.constant(0.0)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor 'strided_slice:0' shape=(2, 1024, 32128) dtype=float32>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "student_logits = student.logits[:,:,0,0]\n",
    "student_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-23-6025d23f9e25>:16: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-23-6025d23f9e25>:16: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See `tf.nn.softmax_cross_entropy_with_logits_v2`.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor 'loss/Sum:0' shape=() dtype=float32>,\n",
       " <tf.Tensor 'loss/Cast:0' shape=(2, 1024) dtype=float32>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "student_task_xent, weights = padded_cross_entropy_loss(student_logits, student.Y)\n",
    "student_task_xent, weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "distill_temperature = 1.0\n",
    "teacher_targets = tf.nn.softmax(model.logits[:,:,0,0] / distill_temperature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor 'Sum:0' shape=() dtype=float32>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(\n",
    "    labels=tf.stop_gradient(teacher_targets),\n",
    "    logits=student_logits / distill_temperature)\n",
    "student_distill_xent = tf.reduce_sum(student_distill_xent * weights)\n",
    "student_distill_xent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_distill_xent *= distill_temperature**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_balance = 0.5\n",
    "phase_loss = task_balance * student_task_xent\n",
    "phase_loss += (1 - task_balance) * student_distill_xent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = phase_loss / tf.reduce_sum(weights)\n",
    "task_loss = student_task_xent / tf.reduce_sum(weights)\n",
    "distill_loss = student_distill_xent / tf.reduce_sum(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "tf.train.create_global_step()\n",
    "global_step = tf.train.get_global_step()\n",
    "learning_rate_warmup_steps = 10000\n",
    "learning_rate = tf.rsqrt(tf.maximum(tf.to_float(global_step), learning_rate_warmup_steps))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import adafactor\n",
    "\n",
    "optimizer = adafactor.AdafactorOptimizer(\n",
    "    learning_rate = learning_rate,\n",
    "    decay_rate = adafactor.adafactor_decay_rate_pow(0.8),\n",
    "    beta1 = 0.0,\n",
    ")\n",
    "train_op = optimizer.minimize(loss, global_step = global_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from base/model.ckpt-500000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from base/model.ckpt-500000\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = [v for v in tf.trainable_variables() if 'student/' not in v.name]\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[6.537529, 6.585449, 6.4896092, None]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run([loss, task_loss, distill_loss, train_op])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'student/model.ckpt'"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var_lists = [v for v in tf.trainable_variables() if 'student/' in v.name]\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.save(sess, 'student/model.ckpt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
